{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "sX4M-kbQZtmS",
    "colab_type": "code",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 215.0
    },
    "outputId": "ce4ac3be-dc76-4ae3-8db3-5df57944a5a2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0100 cost = 0.017576\n",
      "Epoch: 0200 cost = 0.003831\n",
      "Epoch: 0300 cost = 0.001706\n",
      "Epoch: 0400 cost = 0.000970\n",
      "Epoch: 0500 cost = 0.000627\n",
      "Epoch: 0600 cost = 0.000439\n",
      "Epoch: 0700 cost = 0.000325\n",
      "Epoch: 0800 cost = 0.000250\n",
      "Epoch: 0900 cost = 0.000199\n",
      "Epoch: 1000 cost = 0.000162\n",
      "['mak', 'nee', 'coa', 'wor', 'lov', 'hat', 'liv', 'hom', 'has', 'sta'] -> ['e', 'd', 'l', 'd', 'e', 'e', 'e', 'e', 'h', 'r']\n"
     ]
    }
   ],
   "source": [
    "'''\n",
    "  code by Tae Hwan Jung(Jeff Jung) @graykode\n",
    "'''\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "\n",
    "# Tensor type applied to the model's hand-made projection parameters.\n",
    "dtype = torch.FloatTensor\n",
    "\n",
    "# Vocabulary: the 26 lowercase letters, with lookup tables in both directions.\n",
    "char_arr = list('abcdefghijklmnopqrstuvwxyz')\n",
    "word_dict = {ch: idx for idx, ch in enumerate(char_arr)}  # character -> class index\n",
    "number_dict = dict(enumerate(char_arr))  # class index -> character\n",
    "n_class = len(word_dict)  # number of classes (= vocabulary size)\n",
    "\n",
    "# Training words: every letter but the last is fed to the model, the last letter is the label.\n",
    "seq_data = ['make', 'need', 'coal', 'word', 'love', 'hate', 'live', 'home', 'hash', 'star']\n",
    "\n",
    "# TextLSTM Parameters\n",
    "n_step = 3  # presumably the number of input letters per word; not referenced elsewhere in this cell\n",
    "n_hidden = 128  # hidden_size passed to nn.LSTM\n",
    "\n",
    "def make_batch(seq_data):\n",
    "    '''Turn words into one-hot input tensors and class-index targets.\n",
    "\n",
    "    For each word, every character except the last is one-hot encoded to form\n",
    "    the input sequence, and the vocabulary index of the last character is the\n",
    "    target.\n",
    "\n",
    "    Args:\n",
    "        seq_data: iterable of words whose characters are all keys of word_dict.\n",
    "            The words must share one length so the inputs stack into one array.\n",
    "\n",
    "    Returns:\n",
    "        (input_batch, target_batch): a float32 tensor of shape\n",
    "        [batch_size, len(word) - 1, n_class] and an int64 tensor of shape\n",
    "        [batch_size].\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a word contains a character that is not in word_dict.\n",
    "    '''\n",
    "    one_hot = np.eye(n_class)  # built once; row i is the one-hot vector of class i\n",
    "    input_batch, target_batch = [], []\n",
    "\n",
    "    for seq in seq_data:\n",
    "        char_ids = [word_dict[ch] for ch in seq[:-1]]  # e.g. 'm', 'a', 'k' for 'make'\n",
    "        input_batch.append(one_hot[char_ids])\n",
    "        target_batch.append(word_dict[seq[-1]])  # e.g. 'e' for 'make'\n",
    "\n",
    "    # Stack into a single ndarray before converting: building a tensor straight from a\n",
    "    # list of ndarrays is slow and triggers a warning on recent PyTorch releases.\n",
    "    # Variable is a deprecated no-op wrapper around Tensor, so plain tensors are returned.\n",
    "    inputs = torch.tensor(np.array(input_batch), dtype=torch.float32)\n",
    "    targets = torch.tensor(target_batch, dtype=torch.long)\n",
    "    return inputs, targets\n",
    "\n",
    "class TextLSTM(nn.Module):\n",
    "    '''Single-layer LSTM followed by a linear projection onto the vocabulary.\n",
    "\n",
    "    The LSTM output of the last time step is mapped to n_class logits, one score\n",
    "    per candidate next character.\n",
    "    '''\n",
    "\n",
    "    def __init__(self):\n",
    "        super(TextLSTM, self).__init__()\n",
    "\n",
    "        self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden)\n",
    "        # Output projection kept as raw parameters: logits = last_output @ W + b\n",
    "        self.W = nn.Parameter(torch.randn([n_hidden, n_class]).type(dtype))\n",
    "        self.b = nn.Parameter(torch.randn([n_class]).type(dtype))\n",
    "\n",
    "    def forward(self, X):\n",
    "        '''Compute next-character logits for a batch of one-hot sequences.\n",
    "\n",
    "        Args:\n",
    "            X: float tensor of shape [batch_size, n_step, n_class].\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape [batch_size, n_class] holding unnormalised scores.\n",
    "        '''\n",
    "        batch_size = X.size(0)\n",
    "        lstm_input = X.transpose(0, 1)  # nn.LSTM expects [n_step, batch_size, n_class]\n",
    "\n",
    "        # Zero initial states, shaped [num_layers(=1) * num_directions(=1), batch_size, n_hidden].\n",
    "        # They are created on the input's device and dtype so the model keeps working after\n",
    "        # being moved to a GPU (CPU-only states would raise a device-mismatch error there).\n",
    "        hidden_state = torch.zeros(1, batch_size, n_hidden, device=X.device, dtype=X.dtype)\n",
    "        cell_state = torch.zeros(1, batch_size, n_hidden, device=X.device, dtype=X.dtype)\n",
    "\n",
    "        outputs, _ = self.lstm(lstm_input, (hidden_state, cell_state))\n",
    "        last_output = outputs[-1]  # [batch_size, n_hidden]\n",
    "        return torch.mm(last_output, self.W) + self.b  # [batch_size, n_class]\n",
    "\n",
    "input_batch, target_batch = make_batch(seq_data)\n",
    "\n",
    "model = TextLSTM()\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Training: full-batch gradient descent, reporting the loss every 100 epochs.\n",
    "for epoch in range(1000):\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    output = model(input_batch)\n",
    "    loss = criterion(output, target_batch)\n",
    "    if (epoch + 1) % 100 == 0:\n",
    "        # .item() pulls the Python float out of the 0-d loss tensor before formatting.\n",
    "        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "# Prediction: show each word's input letters (the same slice make_batch feeds the model)\n",
    "# next to the predicted final letter.\n",
    "inputs = [sen[:-1] for sen in seq_data]\n",
    "\n",
    "# Inference needs no gradients; no_grad replaces the deprecated .data access.\n",
    "with torch.no_grad():\n",
    "    predict = model(input_batch).max(1, keepdim=True)[1]  # [batch_size, 1] class indices\n",
    "# squeeze(1) keeps the batch dimension, so a single-word batch stays iterable.\n",
    "print(inputs, '->', [number_dict[n.item()] for n in predict.squeeze(1)])"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "TextLSTM-Torch.ipynb",
   "version": "0.3.2",
   "provenance": [],
   "collapsed_sections": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "accelerator": "GPU"
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
